library(R.matlab)
library(ggplot2)
library(RColorBrewer)
library(reshape2)
library(multcomp)
library(nlme)
library(lme4)
library(ISLR)
library(scales)
library(svglite)
library(ez)
library(Hmisc)
library(dplyr)

# Clear the workspace so objects from a previous session cannot leak into this run
rm(list = ls())

## --- User-defined parameters ---

savePlots = 0 # 1 = save generated figures (ggsave into imgSaveDir)
plotAll = 0   # 1 = generate every figure regardless of genPlots
# One switch per figure (1 = generate); the position in this vector identifies the figure
genPlots = c(
  0, # 1 - 2A) WM Task Error Rates: No INT conditions only
  0, # 2 - 2B) WM Task Error Rates: AT and AS INT relative to No INT
  0, # 3 - 3) INT Error Rates by WM task (group = INT task)
  0, # 4 - 5B) Interfering task ERPs (int S1)
  0, # 5 - 5C) ERP peaks: Domain x Intervening task interaction plot
  0  # 6 - 6B and 6D) Alpha power timecourses (averaged across WM domain)
)
# Other plotting locations:
#   Pupillometry figures: VAST_wholeTrialPupilAnalysis.R
#   All-ERP figure (5A): plot_all_ERPs.R

bPal = brewer.pal(4, 'Dark2')

# Earlier primary-ish scheme (complementary = no modality or domain similarity), kept for reference:
# condColors_INT = c('#808080', '#71bf6e', '#ff9933')
# condColors_WM = c('#71bf6e', '#ff9933', '#5889b0', '#e94849')

# Separate color schemes for WM and INT
# INT tasks: grey (no INT), green (temporal), purple (spatial)
condColors_INT = c('#8c8c8c', '#196619', '#63459b')
# WM tasks, two-tone: orange family for auditory, blue family for visual
condColors_WM = c('#802b00', '#cc4400', '#003399', '#3377ff')
condColors_WM_light = c('#e64d00', '#ff7733', '#0055ff', '#99bbff') # 20% lighter
condColors_modality = c('#a83800', '#004de6') # Averages within the orange and blue families
condColors_domain = c('#404040', '#737373')   # Darker/lighter grey: modality-agnostic aesthetic


## --- Other setup ---

numPlots = length(genPlots)
# plotAll switches every figure on, overriding the individual genPlots entries
if (plotAll == 1) {
  genPlots = rep(1, numPlots)
}

# Set paths based on current machine: look the host name up by field name
compname = Sys.info()[["nodename"]]
machineDirs = switch(compname,
  "DESKTOP-BHR0AU7" = c('E:/ANL/Experiments/RESULTS/VAST/DATAforR/',
                        'E:/ANL/Experiments/RESULTS/VAST/FIGURES/ms/'),
  "MSI" = c('F:/ANL/RESULTS/VAST/DATAforR/',
            'F:/ANL/RESULTS/VAST/FIGURES/ms/'),
  "UMN202302853" = c('D:/ANL/RESULTS/VAST/DATAforR/',
                     'D:/ANL/RESULTS/VAST/FIGURES/ms/'),
  "UMN202302892" = c('E:/ANL/RESULTS/VAST/DATAforR/',
                     'E:/ANL/RESULTS/VAST/FIGURES/ms/'),
  stop('set directories for this machine')
)
loadPath = machineDirs[1]   # input .mat files
imgSaveDir = machineDirs[2] # figure output directory
rm(machineDirs)

## --- Load and format the error rate data ---

errorData = readMat(paste(loadPath, 'errorRates.mat', sep = ""))

# Condition labels: one entry per condition (row) of the error-rate arrays
mem_m = unlist(errorData$conds.mod)  # WM modality
mem_d = unlist(errorData$conds.dom)  # WM domain
mem = interaction(mem_m, mem_d)      # WM task = modality x domain
int = unlist(errorData$conds.inter)  # interfering (INT) task

# factor() x (alphabetical levels), then permute those levels by `ord`
relevelBy = function(x, ord) {
  x = factor(x)
  factor(x, levels(x)[ord])
}
# Factor set-up shared by both error data frames: domain temporal before spatial,
# WM-task and INT-task levels re-ordered
formatCondCols = function(df) {
  df$mod = factor(df$mod)
  df$dom = relevelBy(df$dom, c(2, 1))
  df$mem = relevelBy(df$mem, c(3, 1, 4, 2))
  df$int = relevelBy(df$int, c(3, 2, 1))
  df
}

# Average data (one row per condition)
error_df = formatCondCols(data.frame(
  'mod' = mem_m, 'dom' = mem_d, 'mem' = mem, 'int' = int,
  'error' = errorData$avg.errorRates, 'errorSEM' = errorData$error.SEM,
  'error_int' = errorData$avg.errorRates.inter, 'error_intSEM' = errorData$error.SEM.inter,
  'error_comb' = errorData$avg.errorRates.comb, 'error_combSEM' = errorData$error.SEM.comb
))

# Individual subjects data. Subjects are COLUMNS, so c() (column-major) stacks each
# subject's conditions one subject after another.
nConds = nrow(errorData$errorRates)
nSs = ncol(errorData$errorRates)
err_indss = c(errorData$errorRates)
err_int_indss = c(errorData$errorRates.inter)
err_comb_indss = c(errorData$errorRates.comb)

error_df_indss = data.frame(
  'subID' = rep(1:nSs, each = nConds),
  'mod' = rep(mem_m, nSs), 'dom' = rep(mem_d, nSs), 'mem' = rep(mem, nSs), 'int' = rep(int, nSs),
  'err_mem' = err_indss, 'err_int' = err_int_indss, 'err_comb' = err_comb_indss
)
error_df_indss$subID = factor(error_df_indss$subID)
error_df_indss = formatCondCols(error_df_indss)

rm(relevelBy, formatCondCols)


## --- Load and format the pupillometry data ---

# TRIAL WINDOW
pupil_avgs = readMat(paste(loadPath, 'avgPupil.mat', sep=""))
# pupil_stat and p_base are loaded here but not used within this section
pupil_stat = readMat(paste(loadPath, 'pupilStats.mat', sep=""))

p_trial = pupil_avgs$avg.pupil.trim
p_base = pupil_avgs$avg.pupil.base

# Trial-window pupil timecourses: 'time' plus a mean and an _SEM column per condition, named
# <WM task>_<INT task>. In p_trial each condition occupies a run of 3 list slots: the mean is
# in the 2nd slot and the SEM in the 3rd (e.g. AS_none = [[2]]/[[3]], AS_AS = [[5]]/[[6]]).
# NOTE(review): the 1st slot of each run is presumably the condition label - confirm.
# Each slot's data is transposed with t() so it becomes a data-frame column.
p_trial_df = data.frame('time'=t(pupil_avgs$tb.mem), 'AT_AT'=t(p_trial[[17]][[1]]), 'AT_AT_SEM'=t(p_trial[[18]][[1]]), 'AT_AS'=t(p_trial[[14]][[1]]), 'AT_AS_SEM'=t(p_trial[[15]][[1]]),
                        'AT_none'=t(p_trial[[11]][[1]]), 'AT_none_SEM'=t(p_trial[[12]][[1]]), 'AS_AT'=t(p_trial[[8]][[1]]), 'AS_AT_SEM'=t(p_trial[[9]][[1]]),
                        'AS_AS'=t(p_trial[[5]][[1]]), 'AS_AS_SEM'=t(p_trial[[6]][[1]]), 'AS_none'=t(p_trial[[2]][[1]]), 'AS_none_SEM'=t(p_trial[[3]][[1]]),
                        'VT_AT'=t(p_trial[[35]][[1]]), 'VT_AT_SEM'=t(p_trial[[36]][[1]]), 'VT_AS'=t(p_trial[[32]][[1]]), 'VT_AS_SEM'=t(p_trial[[33]][[1]]),
                        'VT_none'=t(p_trial[[29]][[1]]), 'VT_none_SEM'=t(p_trial[[30]][[1]]), 'VS_AT'=t(p_trial[[26]][[1]]), 'VS_AT_SEM'=t(p_trial[[27]][[1]]),
                        'VS_AS'=t(p_trial[[23]][[1]]), 'VS_AS_SEM'=t(p_trial[[24]][[1]]), 'VS_none'=t(p_trial[[20]][[1]]), 'VS_none_SEM'=t(p_trial[[21]][[1]]))


# BASELINE WINDOW
temp = readMat(paste(loadPath,'avgPupil_BL.mat',sep=""))
# Per-subject baseline means: rows = subjects, one column per condition (labels below)
pupil_base = temp$base.means.indSs
condnames = unlist(temp$allconds)
Nblk = length(condnames)
Nss = nrow(pupil_base)
# Labels for the columns of pupil_base, in column order: WM modality (a/v), WM domain (s/t),
# interfering task. NOTE(review): hard-coded; confirm the order matches `condnames`.
WM_mod = c('a','a','a','a','a','a','v','v','v','v','v','v')
WM_dom = c('s','s','s','t','t','t','s','s','s','t','t','t')
percept = c('none','as','at','none','as','at','none','as','at','none','as','at')
# The labels are repeated per subject below, and data.frame() recycles lengths that divide
# evenly, so a count mismatch would silently misalign labels and data.
stopifnot(ncol(pupil_base) == length(WM_mod), Nblk == length(WM_mod))
bl_data = c(pupil_base) # c() vectorizes the matrix column-major (down each column first), e.g.:
# 1 3 5 -->
# 2 4 6 --> 1 2 3 4 5 6
# so bl_data holds all subjects for condition 1, then all subjects for condition 2, ...

pupil_base_df = data.frame('subj'=rep(1:Nss, Nblk), 'WMmod'=rep(WM_mod, each=Nss),
                           'WMdom'=rep(WM_dom, each=Nss), 'INTtask'=rep(percept, each=Nss),
                           'BLmean'=bl_data)
pupil_base_df$subj = factor(pupil_base_df$subj)
pupil_base_df$WMmod = factor(pupil_base_df$WMmod)
pupil_base_df$memcond = interaction(pupil_base_df$WMmod, pupil_base_df$WMdom)
pupil_base_df$memcond = factor(pupil_base_df$memcond)
pupil_base_df$memcond = factor(pupil_base_df$memcond, levels(pupil_base_df$memcond)[c(3,1,4,2)]) # re-order levels
pupil_base_df$WMdom = factor(pupil_base_df$WMdom)
pupil_base_df$WMdom = factor(pupil_base_df$WMdom, levels(pupil_base_df$WMdom)[c(2,1)]) # temporal before spatial
pupil_base_df$INTtask = factor(pupil_base_df$INTtask)
pupil_base_df$INTtask = factor(pupil_base_df$INTtask, levels(pupil_base_df$INTtask)[c(3,2,1)])

# Mean and SEM across subjects for each labelled condition
bl_means = bl_means_sem = rep(NA_real_, length(WM_mod))
for (i in seq_along(WM_mod)) {
  currdata = pupil_base_df$BLmean[pupil_base_df$WMmod==WM_mod[i] & pupil_base_df$WMdom==WM_dom[i] &
                                    pupil_base_df$INTtask==percept[i]]
  bl_means[i] = mean(currdata)
  bl_means_sem[i] = sd(currdata) / sqrt(Nss)
}

pupil_base_avg = data.frame('WMmod'=WM_mod, 'WMdom'=WM_dom, 'mem'=interaction(WM_mod, WM_dom),
                            'INTtask'=percept, 'BL'=bl_means, 'BL_SEM'=bl_means_sem)
pupil_base_avg$WMmod = factor(pupil_base_avg$WMmod)
pupil_base_avg$WMdom = factor(pupil_base_avg$WMdom)
pupil_base_avg$WMdom = factor(pupil_base_avg$WMdom, levels = levels(pupil_base_avg$WMdom)[c(2,1)])
pupil_base_avg$mem = factor(pupil_base_avg$mem)
pupil_base_avg$mem = factor(pupil_base_avg$mem, levels = levels(pupil_base_avg$mem)[c(3,1,4,2)])
pupil_base_avg$INTtask = factor(pupil_base_avg$INTtask)
pupil_base_avg$INTtask = factor(pupil_base_avg$INTtask, levels = levels(pupil_base_avg$INTtask)[c(3,2,1)])

# Same as above, but first averaged across WM domain (no significant effects)
WM_mod2 = c('a','a','a','v','v','v')
percept2 = c('none','as','at','none','as','at')
# Per-subject mean of the spatial- and temporal-domain baselines for one modality x INT-task cell
domainAvg = function(modality, inttask) {
  takeDom = function(dom) {
    pupil_base_df$BLmean[pupil_base_df$WMmod == modality &
                           pupil_base_df$WMdom == dom &
                           pupil_base_df$INTtask == inttask]
  }
  rowMeans(cbind(takeDom('s'), takeDom('t')))
}
# Stack the cells in WM_mod2/percept2 order (each cell = all subjects)
bl_data2 = NULL
for (i in seq_along(WM_mod2)) {
  bl_data2 = c(bl_data2, domainAvg(WM_mod2[i], percept2[i]))
}
rm(domainAvg)

pupil_base_acxDom = data.frame('subj'=rep(1:Nss, Nblk/2), 'WMmod'=rep(WM_mod2, each=Nss),
                               'INTtask'=rep(percept2, each=Nss), 'BLmean'=bl_data2)
pupil_base_acxDom$subj = factor(pupil_base_acxDom$subj)
pupil_base_acxDom$WMmod = factor(pupil_base_acxDom$WMmod)
# INT task levels: alphabetical (as, at, none) reversed to none, at, as
pupil_base_acxDom$INTtask = factor(pupil_base_acxDom$INTtask,
                                   levels = levels(factor(pupil_base_acxDom$INTtask))[c(3, 2, 1)])
# Combined modality x INT-task cell label used by the mixed models
pupil_base_acxDom$cInter = interaction(pupil_base_acxDom$WMmod, pupil_base_acxDom$INTtask)

# Some basic stats on the baseline data, averaged across domain
# Note: stats lose significance with this level of averaging, but the figure may be ok if
#       significance comes from the main effect of WM modality in the full INT-task-specific models.
# Repeated-measures ANOVA: WM modality x INT task, both within subjects
STAT_pupil_base = ezANOVA(pupil_base_acxDom, dv=BLmean, wid=subj, within=.(WMmod, INTtask))
# Linear mixed model over the modality x INT-task cells: random intercept per subject with a
# compound-symmetry within-subject correlation structure
pupil_base_lme = lme(BLmean ~ cInter, random = ~1|subj, correlation=corCompSymm(form=~1|subj), data=pupil_base_acxDom)
# Post-hoc: all pairwise ("Tukey") contrasts between cells, p-values Bonferroni-Holm adjusted
pupil_base_posthoc = summary(glht(pupil_base_lme, linfct=mcp(cInter = "Tukey")), test = adjusted(type = "holm"))
print('Pupil Baseline Post-Hoc Tests')
print(pupil_base_posthoc)
print("WARNING! Bonf-Holm undid, then re-did for 3 important comparisons, MANUALLY!")
# Useful for sorting the pvalues from bonf holm
# sort(abs(pupil_base_posthoc$test$tstat), decreasing=TRUE)

# Same analysis excluding the No Interfering task condition
pupil_base_acxDom_INT = pupil_base_acxDom[pupil_base_acxDom$INTtask != 'none',]
# Re-factor to drop the now-empty 'none' levels before modelling
pupil_base_acxDom_INT$INTtask = factor(pupil_base_acxDom_INT$INTtask)
pupil_base_acxDom_INT$cInter = factor(pupil_base_acxDom_INT$cInter)
STAT_pupil_base_INT = ezANOVA(pupil_base_acxDom_INT, dv=BLmean, wid=subj, within=.(WMmod, INTtask))
pupil_base_lme_INT = lme(BLmean ~ cInter, random = ~1|subj, correlation=corCompSymm(form=~1|subj), data=pupil_base_acxDom_INT)
pupil_base_posthoc_INT = summary(glht(pupil_base_lme_INT, linfct=mcp(cInter = "Tukey")), test = adjusted(type = "holm"))
# Distinct header so this printout cannot be confused with the full-model results above
print('Pupil Baseline Post-Hoc Tests (INT conditions only)')
print(pupil_base_posthoc_INT)

# Average the collapsed data: per-cell mean and SEM across subjects, one entry per
# WM_mod2/percept2 cell (sized from the label vectors rather than a hard-coded count)
bl_means2 = bl_means_sem2 = rep(NA_real_, length(WM_mod2))
for (i in seq_along(WM_mod2)) {
  currdata = pupil_base_acxDom$BLmean[pupil_base_acxDom$WMmod==WM_mod2[i] & pupil_base_acxDom$INTtask==percept2[i]]
  bl_means2[i] = mean(currdata)
  bl_means_sem2[i] = sd(currdata) / sqrt(Nss)
}
pupil_base_acxDom_avg = data.frame('WMmod'=WM_mod2, 'INTtask'=percept2, 'BL'=bl_means2, 'BL_SEM'=bl_means_sem2)
pupil_base_acxDom_avg$WMmod = factor(pupil_base_acxDom_avg$WMmod)
pupil_base_acxDom_avg$INTtask = factor(pupil_base_acxDom_avg$INTtask)
pupil_base_acxDom_avg$INTtask = factor(pupil_base_acxDom_avg$INTtask,
                                       levels = levels(pupil_base_acxDom_avg$INTtask)[c(3,2,1)])
pupil_base_acxDom_avg$cInter = interaction(pupil_base_acxDom_avg$WMmod, pupil_base_acxDom_avg$INTtask)


## --- DF for behav-pupil correlation ---
pupil_metrics = readMat(paste(loadPath, 'pupil_metrics.mat', sep=""))
blk_order = unlist(pupil_metrics$blk.order)
# Metric planes of the 3-D array: 1 = peak latency, 2 = peak amplitude
# (rows presumably subjects, columns blocks in blk_order order - confirm)
peak_lat = pupil_metrics$pupil.metrics[,,1]
peak_amp = pupil_metrics$pupil.metrics[,,2]
rm(list='pupil_metrics')
# Column of blk_order for each error-rate condition (label = WM modality, WM domain, "_", INT task).
# match() gives NA for a label absent from blk_order; stop rather than let the pupil columns
# silently shift out of alignment with error_df.
blk_labels = paste(mem_m, mem_d, "_", int, sep="")
mapping = match(blk_labels, blk_order)
if (anyNA(mapping)) {
  stop('condition(s) missing from pupil_metrics blk.order: ',
       paste(blk_labels[is.na(mapping)], collapse=', '))
}
peak_lat = peak_lat[,mapping]
peak_lat = peak_lat[,int!='none']
peak_lat = peak_lat / 500 # samples --> sec (500 samples per second)
peak_lat_avg = colMeans(peak_lat)
# SEM across subjects (rows); n is taken from this matrix rather than another dataset's count
peak_lat_sem = apply(peak_lat, 2, sd) / sqrt(nrow(peak_lat))
blk_order_red = blk_order[mapping]
blk_order_red = blk_order_red[int!='none']
# One row per INT condition in error_df order (assumed: at/as INT within WM tasks at, as, vt, vs)
pupil_behav_df = data.frame('mem'=rep(c('at','as','vt','vs'), each=2), 'int'=rep(c('at','as'), 4),
                            'INTbehav'=error_df$error_int[error_df$int!='none'],
                            'INTbehavSEM'=error_df$error_intSEM[error_df$int!='none'],
                            'peakLat'=peak_lat_avg, 'peakLatSEM'=peak_lat_sem)
pupil_behav_df$mem = factor(pupil_behav_df$mem)
pupil_behav_df$mem = factor(pupil_behav_df$mem, levels=levels(pupil_behav_df$mem)[c(2,1,4,3)])
pupil_behav_df$int = factor(pupil_behav_df$int)
pupil_behav_df$int = factor(pupil_behav_df$int, levels=levels(pupil_behav_df$int)[c(2,1)])


## --- Compare peak pupil diameter to network imbalance ---

# Reorder columns to error_df's condition order, then keep only INT conditions
peak_amp = peak_amp[,mapping]
peak_amp = peak_amp[,int!='none']
peak_amp_avg = colMeans(peak_amp)
# SEM across subjects (rows); n is taken from this matrix rather than another dataset's count
peak_amp_sem = apply(peak_amp, 2, sd) / sqrt(nrow(peak_amp))
# NI = Network Imbalance; a simple additive model of LFC network activation, quantifying the
# degree to which a pair of tasks drives both networks as opposed to loading onto one
NI = c(1.56, 1.14, 1.14, 0.72, 0.99, 0.57, 0.38, 0.8) # pulled from "M2_driveByMichalka.xlsx"
pupil_model_df = data.frame('mem'=rep(c('at','as','vt','vs'), each=2), 'int'=rep(c('at','as'), 4),
                            'peakAmp'=peak_amp_avg, 'peakAmpSEM'=peak_amp_sem, 'NI'=NI,
                            'NIsem_FAKE'=rep(0,8))
pupil_model_df$mem = factor(pupil_model_df$mem)
pupil_model_df$mem = factor(pupil_model_df$mem, levels=levels(pupil_model_df$mem)[c(2,1,4,3)])
pupil_model_df$int = factor(pupil_model_df$int)
pupil_model_df$int = factor(pupil_model_df$int, levels=levels(pupil_model_df$int)[c(2,1)])
pupil_model_df$cond = interaction(pupil_model_df$mem, pupil_model_df$int)


## --- Load and format the ERP data ---
erp_avgs = readMat(paste(loadPath, 'avg_int1_ERP.mat', sep=""))

avg = erp_avgs$int1.avg
sem = erp_avgs$int1.sem
tb = erp_avgs$tb * 1000 # convert sec to ms

# avg (and, via the same indices, sem) are flat lists: odd slots hold condition names,
# the following even slot holds that condition's data
nameInds = seq(from = 1, to = length(avg), by = 2)
dataInds = nameInds + 1
keyconds = unlist(avg[nameInds])
avgdata = avg[dataInds]
semdata = sem[dataInds]

# Data slot k of a list, transposed into a column (NOTE(review): assumes each slot stores a row)
erpColumn = function(lst, k) t(lst[[k]][[1]])

# Create data frames with placeholder names (slot order: at_at, as_at, vt_at, vs_at, at_as, as_as, vt_as, vs_as)
avg_df = data.frame('time' = t(tb),
                    'at_at' = erpColumn(avgdata, 1), 'as_at' = erpColumn(avgdata, 2),
                    'vt_at' = erpColumn(avgdata, 3), 'vs_at' = erpColumn(avgdata, 4),
                    'at_as' = erpColumn(avgdata, 5), 'as_as' = erpColumn(avgdata, 6),
                    'vt_as' = erpColumn(avgdata, 7), 'vs_as' = erpColumn(avgdata, 8))
sem_df = data.frame('time' = t(tb),
                    'at_at_sem' = erpColumn(semdata, 1), 'as_at_sem' = erpColumn(semdata, 2),
                    'vt_at_sem' = erpColumn(semdata, 3), 'vs_at_sem' = erpColumn(semdata, 4),
                    'at_as_sem' = erpColumn(semdata, 5), 'as_as_sem' = erpColumn(semdata, 6),
                    'vt_as_sem' = erpColumn(semdata, 7), 'vs_as_sem' = erpColumn(semdata, 8))
# Cut out the baseline and trim the last 200 ms
# avg_df = avg_df[avg_df$time >= 0 & avg_df$time <= 300,]
# sem_df = sem_df[sem_df$time >= 0 & sem_df$time <= 300,]

# Combine avg and sem data into a single df (dropping sem_df's duplicate time column)
erp_df = cbind(avg_df, sem_df[, -1])
rm(erpColumn)



erp_mat = readMat(paste(loadPath, 'ERP_peaks.mat', sep=""))
ERP_peaks = erp_mat$ERP.peaks
condnames_INT = unlist(erp_mat$cnames.int)
Nss = nrow(ERP_peaks)
# Condition names are split by character position (presumably "<mod><dom>_<int>", e.g. "at_as"):
# char 1 = WM modality, char 2 = WM domain, chars 4-5 = interfering task
WM_mod = substr(condnames_INT, 1, 1)
WM_dom = substr(condnames_INT, 2, 2)
int_task = substr(condnames_INT, 4, 5)

# Long format: c(ERP_peaks) flattens column by column, so each condition's Nss values are contiguous
ERP_peak_df = data.frame('subj' = rep(1:Nss, length(condnames_INT)),
                         'mod' = rep(WM_mod, each = Nss),
                         'dom' = rep(WM_dom, each = Nss),
                         'int' = rep(int_task, each = Nss),
                         'peak' = c(ERP_peaks))
ERP_peak_df$subj = factor(ERP_peak_df$subj)
ERP_peak_df$mod = factor(ERP_peak_df$mod)
# Swap the two alphabetical levels of dom (temporal first) and of int
for (col in c('dom', 'int')) {
  f = factor(ERP_peak_df[[col]])
  ERP_peak_df[[col]] = factor(f, levels = levels(f)[c(2, 1)])
}
rm(col, f)
ERP_peak_df$domint = interaction(ERP_peak_df$dom, ERP_peak_df$int)

# # Aud vs Vis ERPs (avg acx domain and INT)
# ERP_peaks = readMat(paste(loadPath, 'INT1_peaks.mat', sep=""))
# ERP_peaks = ERP_peaks$peaks # they come in averaged across WM domain
# a = colMeans(rbind(ERP_peaks[[1]], ERP_peaks[[2]], ERP_peaks[[5]], ERP_peaks[[6]]))
# v = colMeans(rbind(ERP_peaks[[3]], ERP_peaks[[4]], ERP_peaks[[7]], ERP_peaks[[8]]))
# cnames = c('aWM', 'vWM')
# # Individual Data
# erp_peaks_modOnly = data.frame('subj'=rep(1:Nss, 2), 'WMmod'=rep(c('a','v'), each=Nss), 
#                                'peak'=c(a, v))
# erp_peaks_modOnly$subj = factor(erp_peaks_modOnly$subj)
# erp_peaks_modOnly$WMmod = factor(erp_peaks_modOnly$WMmod)
# 
# # ERP averages across domain and INT task (Modality main effect)
# red_WM_mod = c("a","v")
# erp_means = erp_means_sem = c(NA)
# length(erp_means) = length(erp_means_sem) = 2
# for (i in 1:2) {
#   currdata = erp_peaks_modOnly$peak[erp_peaks_modOnly$WMmod==red_WM_mod[i]]
#   erp_means[i] = mean(currdata)
#   erp_means_sem[i] = sd(currdata) / sqrt(Nss)
# }
# erp_peaks_modOnly_avg = data.frame('WMmod'=red_WM_mod, 'peak'=erp_means, 'peak_sem'=erp_means_sem)
# erp_peaks_modOnly_avg$WMmod = factor(erp_peaks_modOnly_avg$WMmod)
# 
# 
# # For ERP domain differences (domain:interfering interaction) -- avg acx modality
# t_at = colMeans(rbind(ERP_peaks[[1]], ERP_peaks[[3]]))
# s_at = colMeans(rbind(ERP_peaks[[2]], ERP_peaks[[4]]))
# t_as = colMeans(rbind(ERP_peaks[[5]], ERP_peaks[[7]]))
# s_as = colMeans(rbind(ERP_peaks[[6]], ERP_peaks[[8]]))
# erp_acx_mod = data.frame('subj'=rep(1:Nss, 4), 'WMdom'=rep(c('t','s','t','s'), each=Nss),
#                          'int'=rep(c('at','at','as','as'), each=Nss), 
#                          'erp'=c(t_at, s_at, t_as, s_as))
# erp_acx_mod$subj = factor(erp_acx_mod$subj)
# erp_acx_mod$WMdom = factor(erp_acx_mod$WMdom)
# erp_acx_mod$WMdom = factor(erp_acx_mod$WMdom, levels=levels(erp_acx_mod$WMdom)[c(2,1)])
# erp_acx_mod$int = factor(erp_acx_mod$int)
# erp_acx_mod$int = factor(erp_acx_mod$int, levels=levels(erp_acx_mod$int)[c(2,1)])
# erp_acx_mod$interact = interaction(erp_acx_mod$WM, erp_acx_mod$int)
# 
# # Average ERP peaks after collapsing across modality
# erp_means = erp_means_sem = c(NA)
# length(erp_means) = length(erp_means_sem) = 4
# for (i in 1:4) {
#   currdata = erp_acx_mod$erp[erp_acx_mod$interact==levels(erp_acx_mod$interact)[i]]
#   erp_means[i] = mean(currdata)
#   erp_means_sem[i] = sd(currdata) / sqrt(Nss)
# }
# erp_acx_mod_avg = data.frame('interact'=levels(erp_acx_mod$interact), 'WMdom'=c('t','s','t','s'),
#                              'int'=c('at','at','as','as'), 'erp'=erp_means, 'erp_sem'=erp_means_sem)
# erp_acx_mod_avg$interact = factor(erp_acx_mod_avg$interact)
# erp_acx_mod_avg$interact = factor(erp_acx_mod_avg$interact, levels=levels(erp_acx_mod_avg$interact)[c(4,2,3,1)])
# erp_acx_mod_avg$WMdom = factor(erp_acx_mod_avg$WMdom)
# erp_acx_mod_avg$WMdom = factor(erp_acx_mod_avg$WMdom, levels=levels(erp_acx_mod_avg$WMdom)[c(2,1)])
# erp_acx_mod_avg$int = factor(erp_acx_mod_avg$int)
# erp_acx_mod_avg$int = factor(erp_acx_mod_avg$int, levels=levels(erp_acx_mod_avg$int)[c(2,1)])


## --- Load and format the TF timecourse data ---

power_timecourses = readMat(paste(loadPath, 'simple_timecourses.mat', sep=""))
conds = unlist(power_timecourses$condnames)
# Alpha power: rows = conditions, columns = time points; plus versions collapsed across WM domain
a = power_timecourses$alphapow.mean
a_sem = power_timecourses$alphapow.sem
a_domCol = power_timecourses$alphapow.mean.domCollapsed
a_domCol_sem = power_timecourses$alphapow.sem.domCollapsed
# NOTE: `t` holds the theta matrix and shadows base::t() by name; calls like t(x) still reach
# base::t because R skips non-function bindings when looking up a function.
t = power_timecourses$thetapow.mean
t_sem = power_timecourses$thetapow.sem
tb_alpha = power_timecourses$tb.alpha
tb_theta = power_timecourses$tb.theta

# Row order of the 12-condition matrices (<WM task>_<INT task>)
cond12 = c('as_as', 'as_at', 'as_none', 'at_as', 'at_at', 'at_none',
           'vs_as', 'vs_at', 'vs_none', 'vt_as', 'vt_at', 'vt_none')
# Data frame with 'time', then one mean column per condition, then '<cond>_sem' columns;
# row k of meanMat/semMat is the timecourse for condLabels[k]
timecourseDF = function(timebase, meanMat, semMat, condLabels) {
  rowsAsCols = function(m, nms) setNames(lapply(seq_along(nms), function(k) m[k, ]), nms)
  data.frame('time' = t(timebase),
             rowsAsCols(meanMat, condLabels),
             rowsAsCols(semMat, paste0(condLabels, '_sem')))
}

alpha_df = timecourseDF(tb_alpha, a, a_sem, cond12)
alpha_df_domCol = timecourseDF(tb_alpha, a_domCol, a_domCol_sem,
                               c('a_as', 'a_at', 'a_none', 'v_as', 'v_at', 'v_none'))
# Theta pre-Interfering task
theta_df = timecourseDF(tb_theta, t, t_sem, cond12)

# Theta from interfering task to end of memory retention
t_wh = power_timecourses$thetapow.wh # individual subjects data
t_wh_mean = power_timecourses$thetapow.wh.mean
t_wh_sem = power_timecourses$thetapow.wh.sem
tb_theta_wh = power_timecourses$tb.theta.wh
theta_wh_df = timecourseDF(tb_theta_wh, t_wh_mean, t_wh_sem, cond12)

# Difference version of the above (thetapow.wh.diff*): INT conditions only (no 'none'
# columns), still split by WM domain. NOTE(review): presumably relative to the matching
# No-INT condition - confirm.
t_wh_diff = power_timecourses$thetapow.wh.diff # individual subjects data
t_wh_diff_mean = power_timecourses$thetapow.wh.diff.mean
t_wh_diff_sem = power_timecourses$thetapow.wh.diff.sem
theta_wh_diff_df = timecourseDF(tb_theta_wh, t_wh_diff_mean, t_wh_diff_sem,
                                c('as_as', 'as_at', 'at_as', 'at_at',
                                  'vs_as', 'vs_at', 'vt_as', 'vt_at'))

rm(timecourseDF, cond12)
rm(power_timecourses)


## --- Load and format the baseline power spectrum data ---
BL_pow = readMat(paste(loadPath, 'BL_powspectra.mat', sep=""))
# Mean and SEM baseline spectra (Pz and AFz, per the field names), one row per frequency bin 1..80
BL_pow_df = data.frame('freq' = seq_len(80),
                       'Pz_mean' = t(BL_pow$BL.ps.Pz.mean), 'Pz_sem' = t(BL_pow$BL.ps.Pz.sem),
                       'AFz_mean' = t(BL_pow$BL.ps.AFz.mean), 'AFz_sem' = t(BL_pow$BL.ps.AFz.sem))
# Only plot up to 60 Hz
BL_pow_df = subset(BL_pow_df, freq <= 60)







# *************************************************************************************
# _____________________________________________________________________________________
#
# --- PLOTTING ---  
# _____________________________________________________________________________________
# *************************************************************************************


# 2A) WM Task Error Rates: No INT conditions only

if (genPlots[1] == 1) {
  # Restrict to No INT conditions
  error_df_red = error_df[error_df$int=="none",]
  error_df_red$int = factor(error_df_red$int)
  # Color grouping: modality x domain, levels reordered to AT, AS, VT, VS
  error_df_red$interact = interaction(error_df_red$mod, error_df_red$dom)
  error_df_red$interact = factor(error_df_red$interact, levels=levels(error_df_red$interact)[c(1,3,2,4)])
  # Numeric x position of each WM task (mem levels AT, AS, VT, VS -> 1..4) for the mean bars
  error_df_red$xpos = as.numeric(error_df_red$mem)
  error_df_indss_red = error_df_indss[error_df_indss$int=="none",]
  error_df_indss_red$int = factor(error_df_indss_red$int)
  error_df_indss_red$interact = interaction(error_df_indss_red$mod, error_df_indss_red$dom)
  error_df_indss_red$interact = factor(error_df_indss_red$interact, levels=levels(error_df_indss_red$interact)[c(1,3,2,4)])
  # Jitter the x-axis positions around each WM task (int has a single level here, so the
  # int-based offset term is 0)
  error_df_indss_red_jit = error_df_indss_red %>%
    mutate(mem_jit = as.numeric(int)*0.265-0.265 + jitter(as.numeric(mem), 0.3),
           grouping = interaction(subID, mem))
  # Create plot
  tplot = ggplot(data=error_df_indss_red_jit, aes(x=mem_jit, y=err_mem)) +
    # individual subjects data points
    geom_point(fill='#CCCCCC', shape=21, size=4, stroke=1,
               alpha=1, color='black') +
    # averages: mean +/- SEM error bars ...
    geom_errorbar(data=error_df_red, aes(x=mem, ymin=error - errorSEM,
                                         ymax=error + errorSEM, color=interact),
                  size=1.25, width=0.1, inherit.aes= FALSE) +
    # ... and a horizontal bar at each mean. Drawn from error_df_red (one segment per WM task)
    # so position, height and color all come from the same row, independent of row order.
    geom_segment(data=error_df_red, aes(x=xpos - 0.15, xend=xpos + 0.15, y=error, yend=error,
                                        color=interact),
                 size=2.5, inherit.aes=FALSE) +
    
    # geom_point(data=error_df_red, aes(x=mem, y=error, fill=interact, color=interact), shape=21,
    #            size=6, stroke=2, inherit.aes= FALSE) +
    scale_fill_manual(name='WM Task', breaks=levels(error_df_indss_red$interact),
                      labels=c('AT','AS','VT','VS'), values=condColors_WM, guide="none") +
    scale_color_manual(name='WM Task', breaks=levels(error_df_indss_red$interact),
                       labels=c('AT','AS','VT','VS'), values=condColors_WM, guide="none") +
    scale_x_discrete(name='', breaks=c('a.t','a.s','v.t','v.s'),
                     labels=c('AT','AS','VT','VS')) +
    scale_y_continuous(name='WM Error Rate', breaks=seq(0,0.6,0.2), expand=c(0,0)) +
    coord_cartesian(ylim=c(-0.05, 0.7)) +
    theme(panel.background = element_blank(),
          panel.grid.major.y = element_blank(),
          panel.grid.major.x = element_blank(),
          aspect.ratio = 1,
          panel.border = element_rect(colour = "black", fill = NA, size=2),
          plot.margin=unit(c(2,2,2,2), "lines"),
          legend.direction = 'vertical',
          axis.text.x = element_text(size=32, colour = "black"),
          axis.text.y=element_text(size=32, colour = "black"),
          axis.title.x = element_text(size=40, vjust=0.1),
          axis.title.y=element_text(size=40, vjust=0.2))
  
  windows()
  print(tplot)
  
  if (savePlots == 1) {
    # Save this figure explicitly rather than relying on last_plot()
    ggsave(filename = paste(imgSaveDir,'ms_2a_WM_errors_NoINT.svg',sep=""), plot=tplot, width=10,height=7)
  }
}


# _____________________________________________________________________________________________________ 

# 2B) WM Task Error Rates: AT and AS INT relative to No INT

if (genPlots[2] == 1) {
  # Construct new error difference data frames
  # Individual Subjects
  memconds = levels(error_df_indss$mem)
  diff_at = error_df_indss$err_mem[error_df_indss$int=="at"] -
    error_df_indss$err_mem[error_df_indss$int=="none"]
  diff_as = error_df_indss$err_mem[error_df_indss$int=="as"] -
    error_df_indss$err_mem[error_df_indss$int=="none"]
  error_df_indss_diff = data.frame('int'=rep(c('at','as'), each=Nss*4),
                                   'subj'=rep(rep(1:Nss, each=4), 2),
                                   'mem'=rep(memconds, Nss*2),
                                   'relerr'=c(diff_at, diff_as))
  error_df_indss_diff$int = factor(error_df_indss_diff$int)
  error_df_indss_diff$int = factor(error_df_indss_diff$int, levels=levels(error_df_indss_diff$int)[c(2,1)])
  error_df_indss_diff$subj = factor(error_df_indss_diff$subj) 
  error_df_indss_diff$mem = factor(error_df_indss_diff$mem)
  error_df_indss_diff$mem = factor(error_df_indss_diff$mem, levels=levels(error_df_indss_diff$mem)[c(2,1,4,3)])
  # Averages
  diff_at_avg = diff_as_avg = diff_at_SEM = diff_as_SEM = rep(0, 4)
  memconds = levels(error_df_indss_diff$mem)
  for (wm in 1:length(memconds)) {
    curr_AT = error_df_indss_diff$relerr[error_df_indss_diff$mem==memconds[wm] &
                                           error_df_indss_diff$int=='at']
    diff_at_avg[wm] = mean(curr_AT)
    diff_at_SEM[wm] = sd(curr_AT) / sqrt(Nss)
    curr_AS = error_df_indss_diff$relerr[error_df_indss_diff$mem==memconds[wm] &
                                           error_df_indss_diff$int=='as']
    diff_as_avg[wm] = mean(curr_AS)
    diff_as_SEM[wm] = sd(curr_AS) / sqrt(Nss)
  }
  error_df_diff = data.frame('mem'=rep(memconds, 2), 'int'=rep(c('at','as'), each=4),
                             'relerr'=c(diff_at_avg, diff_as_avg),
                             'relerr.SEM'=c(diff_at_SEM, diff_as_SEM))
  error_df_diff$mem = factor(error_df_diff$mem)
  error_df_diff$mem = factor(error_df_diff$mem, levels=levels(error_df_diff$mem)[c(2,1,4,3)])
  error_df_diff$int = factor(error_df_diff$int)
  error_df_diff$int = factor(error_df_diff$int, levels=levels(error_df_diff$int)[c(2,1)])
  for (wm in 1:length(memconds)) {
    # Restrict to the current WM condition (and jitter ind Ss points)
    error_df_indss_diff_jit = error_df_indss_diff[error_df_indss_diff$mem==memconds[wm],] %>%
      mutate(int_jit = jitter(as.numeric(int), 0.3),
             grouping = interaction(subj, int))
    error_df_indss_diff_jit$mem = factor(error_df_indss_diff_jit$mem)
    error_df_diff_red = error_df_diff[error_df_diff$mem==memconds[wm],]
    error_df_diff_red$mem = factor(error_df_diff_red$mem)
    # Create plot
    tplot = ggplot(data=error_df_indss_diff_jit, aes(x=int_jit, y=relerr)) +
      # horizontal line for zero change
      geom_hline(yintercept=0, color='black', linetype=2, size=2) +
      # individual subjects data points
      geom_point(fill='#CCCCCC', shape=21, size=4, stroke=1,
                 alpha=1, color='black') +
      # averages
      geom_errorbar(data=error_df_diff_red, aes(x=int, ymin=relerr - relerr.SEM,
                                                ymax=relerr + relerr.SEM),
                    color=condColors_WM[wm], size=1.25, width=0.1, inherit.aes= FALSE) +
      # geom_point(data=error_df_diff_red, aes(x=int, y=relerr), fill=condColors_WM[wm], 
      #            color=condColors_WM[wm], shape=21, size=6, stroke=2, inherit.aes= FALSE) +
      geom_segment(aes(x=0.85, xend=1.15, y=error_df_diff_red$relerr[1], yend=error_df_diff_red$relerr[1]),
                   size=2.5, color=condColors_WM[wm], inherit.aes=FALSE) +
      geom_segment(aes(x=1.85, xend=2.15, y=error_df_diff_red$relerr[2], yend=error_df_diff_red$relerr[2]),
                   size=2.5, color=condColors_WM[wm], inherit.aes=FALSE) +
      scale_x_discrete(name='', breaks=c('at','as'),
                       labels=c('AT','AS')) +
      scale_y_continuous(name='Error Difference (r.e. No INT)', breaks=seq(-0.2,0.4,0.2), expand=c(0,0)) +
      coord_cartesian(ylim=c(-0.3, 0.4)) +
      theme(panel.background = element_blank(),
            panel.grid.major.y = element_blank(),
            panel.grid.major.x = element_blank(),
            aspect.ratio = 1.6,
            panel.border = element_rect(colour = "black", fill = NA, size=2),
            plot.margin=unit(c(2,2,2,2), "lines"),
            legend.direction = 'vertical',
            axis.text.x = element_text(size=32, colour = "black"),
            axis.text.y=element_text(size=32, colour = "black"),
            axis.title.x = element_text(size=40, vjust=0.1),
            axis.title.y=element_text(size=40, vjust=0.2))
    
    windows()
    print(tplot)
    
    if (savePlots == 1) {
      ggsave(filename = paste(imgSaveDir,'ms_2b_WM_relerrors_',substr(memconds[wm],1,1),
                              substr(memconds[wm],3,3),'WM.svg',sep=""), width=10,height=7)
    }
  }
}


# _____________________________________________________________________________________________________ 

# 3) INT Error Rates by WM task (group = INT task)

if (genPlots[3] == 1) {
  # Figure 3: INT-task error rates for each interfering task (AT, AS), with
  # individual subjects split by WM task and group means (+/- SEM) overlaid.

  # Drop No-INT rows; re-factoring removes the now-unused 'none' level.
  error_df_reduced = error_df_indss[error_df_indss$int!='none',]
  error_df_reduced$int = factor(error_df_reduced$int)
  # Continuous x position = INT-task index + WM-task offset (0.2 per WM task)
  # + jitter. With the 0.25 nudge applied below, a subject's point lands near
  # int + 0.2*mem - 0.5, the centre of the matching group-mean segment
  # (assumes int codes to AT = 1, AS = 2 -- TODO confirm factor level order).
  error_df_reduced = error_df_reduced %>%
    mutate(int_jit = as.numeric(mem)*0.2-0.75 + jitter(as.numeric(int), 0.15),
           grouping = interaction(subID, int))
  error_df_avg_red = error_df[error_df$int %in% c('at','as'),]
  pd = position_dodge(0.8)

  # Group-mean segments, one per WM task x INT task. Assumes error_df_avg_red
  # alternates AT/AS rows within each WM task, in condColors_WM order
  # (odd rows = AT INT, even rows = AS INT) -- verify against error_df.
  # annotate() draws each segment exactly once; a data-less geom_segment()
  # would inherit error_df_reduced and repeat the segment for every row.
  n_wm = length(condColors_WM)
  at_rows = 2*seq_len(n_wm) - 1
  as_rows = 2*seq_len(n_wm)
  seg_left = 0.45 + 0.2*seq_len(n_wm)   # AT INT left ends: 0.65, 0.85, 1.05, 1.25
  at_means = error_df_avg_red$error_int[at_rows]
  as_means = error_df_avg_red$error_int[as_rows]
  mean_segments = list(
    # AT INT average lines
    annotate('segment', x=seg_left, xend=seg_left + 0.1, y=at_means, yend=at_means,
             size=2.5, colour=condColors_WM),
    # AS INT average lines (same layout, shifted one INT-task unit right)
    annotate('segment', x=seg_left + 1, xend=seg_left + 1.1, y=as_means, yend=as_means,
             size=2.5, colour=condColors_WM)
  )

  tplot = ggplot(data=error_df_reduced) +
    # Layer 1: lines connecting each subject's points
    geom_line(aes(x=int_jit, y=err_int, group=grouping), color='#a6a6a6', size=0.8,
              position=position_nudge(x=0.25)) +
    # Layer 2: individual data points (sizes scaled empirically to match Plot 1; sm. size=2.286)
    geom_point(aes(x=int_jit, y=err_int, group=grouping), fill='#CCCCCC', color='black', shape=21,
               size=3.81, stroke=0.8, position=position_nudge(x=0.25)) +
    # Layer 3: group SEM error bars, dodged by WM task
    geom_errorbar(data=error_df_avg_red,
                  aes(x=int, ymin=error_int-error_intSEM, ymax=error_int+error_intSEM, group=mem, color=mem),
                  size=1.25, width=0.1, position=pd, inherit.aes= FALSE) +
    # Layer 4: group-mean segments
    mean_segments +
    scale_fill_manual(name='WM Task', labels=c('AT','AS','VT','VS'),
                      values=condColors_WM, guide='none') +
    scale_color_manual(name='WM Task', labels=c('AT','AS','VT','VS'),
                       values=condColors_WM, guide='none') +
    scale_x_discrete(name='Interfering Task', breaks=c('at', 'as'),
                     labels=c('AT', 'AS')) +
    scale_y_continuous(name='INT Error Rate', breaks=seq(0,0.6,0.2), expand=c(0,0)) +
    coord_cartesian(ylim=c(-0.03,0.65)) +
    theme(panel.background = element_blank(),
          aspect.ratio = 0.7,
          panel.border = element_rect(colour = "black", fill = NA),
          plot.margin=unit(c(1,1,1,1), "lines"),
          legend.direction = 'vertical',
          legend.position = c(0.2,0.85),
          legend.background = element_rect(colour='black',size=1),
          legend.title = element_text(size=20),
          legend.text = element_text(size=18),
          legend.key = element_blank(),
          plot.title = element_text(size=14, hjust=0.5, margin = margin(t=10, b=-25)),
          axis.text.x = element_blank(), axis.text.y = element_blank(),
          axis.title.x = element_blank(), axis.title.y = element_blank())

  windows()
  print(tplot)

  if (savePlots == 1) {
    ggsave(filename = paste0(imgSaveDir, 'ms_3_INT_perc_corr.svg'), plot=tplot,
           width=10, height=7)
  }
}

# _____________________________________________________________________________________________________

# 5B) Intervening task ERPs (int S1)
if (genPlots[4]==1) {
  # Figure 5B: grand-average ERPs to the interfering-task S1, one figure per
  # interfering task (AT, AS), with one mean trace + SEM ribbon per WM task.
  # Reads erp_df columns `time` and <wm>_<int> / <wm>_<int>_sem
  # (e.g. vt_as, vt_as_sem).

  # WM-task column prefixes, in condColors_WM order
  wm_prefixes = c('at', 'as', 'vt', 'vs')
  # Interfering task -> output file name
  erp_files = c(at='ms_5a_ERPs_INTs1_AT.svg', as='ms_5b_ERPs_INTs1_AS.svg')

  # SEM ribbon and mean trace for WM task index `wm` under INT task `int_task`.
  # aes() is evaluated lazily (at print time), so the column names are bound in
  # this function's own environment rather than read from a loop variable.
  erp_layers = function(wm, int_task) {
    mean_col = paste0(wm_prefixes[wm], '_', int_task)
    sem_col = paste0(mean_col, '_sem')
    list(
      ribbon = geom_ribbon(aes(ymin=.data[[mean_col]] - .data[[sem_col]],
                               ymax=.data[[mean_col]] + .data[[sem_col]]),
                           fill=condColors_WM[wm], alpha=0.2),
      line = geom_line(aes(y=.data[[mean_col]]), color=condColors_WM[wm], size=3)
    )
  }

  erp_theme = theme(panel.background = element_blank(),
                    panel.grid.major.y = element_blank(),
                    panel.grid.major.x = element_blank(),
                    aspect.ratio = 1,
                    panel.border = element_rect(colour = "black", fill = NA, size=2),
                    plot.margin=unit(c(2,2,2,2), "lines"),
                    legend.direction = 'vertical',
                    axis.text.x = element_text(size=32, colour = "black"),
                    axis.text.y=element_text(size=32, colour = "black"),
                    axis.title.x = element_text(size=40, vjust=0.1),
                    axis.title.y=element_text(size=40, vjust=0.2))

  for (int_task in names(erp_files)) {
    layers = lapply(seq_along(wm_prefixes), erp_layers, int_task=int_task)

    tplot = ggplot(data=erp_df, aes(x=time)) +
      # Vertical markers at 185 and 225 ms, drawn once each (a data-less
      # geom_segment would repeat them for every row of erp_df)
      annotate('segment', x=c(185, 225), xend=c(185, 225), y=13, yend=Inf,
               size=1.5, linetype=1, colour='#8c8c8cff') +
      # SEM ribbons first so every mean trace is drawn on top of them
      lapply(layers, `[[`, 'ribbon') +
      lapply(layers, `[[`, 'line') +
      scale_x_continuous(name='Time (ms)', breaks=seq(0,400,by=100), expand=c(0,0)) +
      scale_y_continuous(name='Scalp Potential (\u00b5V)', breaks=seq(-2,14,by=2), expand=c(0,0)) +
      coord_cartesian(xlim=c(0,450), ylim=c(-2.5,14)) +
      erp_theme

    windows()
    print(tplot)

    if (savePlots == 1) {
      ggsave(filename = paste0(imgSaveDir, erp_files[[int_task]]), plot=tplot,
             width=8, height=8, units="in")
    }
  }
}


# _____________________________________________________________________________________________________


# 5C) ERP peaks: Domain x Intervening task interaction plot
if (genPlots[5]==1) {
  # Figure 5C: per-subject mean ERP peak amplitude for temporal vs. spatial
  # WM domain, one figure per interfering task (PART1 = AT INT, PART2 = AS INT).
  # domint values are matched as '<domain>.<int>' (t/s . at/as).

  # Interfering task -> output file suffix
  int_parts = c(at='PART1', as='PART2')

  peak_theme = theme(panel.background = element_blank(),
                     panel.grid.major.y = element_blank(),
                     panel.grid.major.x = element_blank(),
                     aspect.ratio = 2,
                     panel.border = element_rect(colour = "black", fill = NA, size=2),
                     plot.margin=unit(c(2,2,2,2), "lines"),
                     legend.direction = 'vertical',
                     axis.text.x = element_text(size=12, colour = "black"),
                     axis.text.y=element_text(size=12, colour = "black"),
                     axis.title.x = element_text(size=16, vjust=0.1),
                     axis.title.y=element_text(size=16, vjust=0.2))

  for (int_task in names(int_parts)) {
    # Temporal-domain and spatial-domain condition codes for this INT task
    conds = paste0(c('t.', 's.'), int_task)

    tplot = ERP_peak_df %>%
      filter(domint %in% conds) %>%
      group_by(subj, domint) %>%
      summarise(sAvg = mean(peak)) %>%
      ungroup() %>%
      ggplot(data = .) +
      aes(x=domint, y=sAvg) +
      # Lines connecting each subject's two points
      geom_line(aes(group=subj), color='black', size=1.5) +
      geom_point(fill='#CCCCCC', shape=21, size=5, stroke=1,
                 alpha=1, color='#b6b6b6') +
      # Group mean and SEM
      stat_summary(fun.data = mean_se, geom = "errorbar", color='#000000', size=1, width=0.15) +
      stat_summary(fun = mean, geom = "point", color='#000000', fill='#000000', shape=21, size=8) +
      scale_x_discrete(name='', breaks=conds,
                       labels=c('Temporal','Spatial')) +
      scale_y_continuous(name='P2 Amplitude (\u00b5V)', breaks=seq(0,25,5), expand=c(0,0)) +
      coord_cartesian(ylim=c(-3, 26)) +
      peak_theme

    windows()
    print(tplot)

    if (savePlots == 1) {
      ggsave(filename = paste0(imgSaveDir, 'ms_5d_ERP_dom_int_interact_',
                               int_parts[[int_task]], '.svg'),
             plot=tplot, width=4.5, height=7)
    }
  }
}


# _____________________________________________________________________________________________________

# 6B and 6D) Alpha power timecourses (averaged across WM domain)
if (genPlots[6]==1) {
  # Figure 6: alpha-power timecourses for auditory ('a') and visual ('v') WM,
  # one mean trace + SEM ribbon per interfering-task condition (none, AS, AT).
  # Reads alpha_df_domCol columns `time` and <modality>_<int> /
  # <modality>_<int>_sem (e.g. v_at, v_at_sem).

  # WM modality prefix -> output file name
  alpha_files = c(a='ms_6a_alpha_aWM.svg', v='ms_6b_alpha_vWM.svg')
  # INT condition -> colour, listed in drawing order (none, as, at).
  # AT/AS reuse the auditory WM colours (condColors_WM[1] = AT, [2] = AS).
  int_colors = c(none=condColors_INT[1], as=condColors_WM[2], at=condColors_WM[1])

  # SEM ribbon and mean trace for one INT condition within one WM modality.
  # aes() is evaluated lazily (at print time), so the column names are bound in
  # this function's own environment rather than read from a loop variable.
  alpha_layers = function(int, modality) {
    mean_col = paste0(modality, '_', int)
    sem_col = paste0(mean_col, '_sem')
    list(
      ribbon = geom_ribbon(aes(ymin=.data[[mean_col]] - .data[[sem_col]],
                               ymax=.data[[mean_col]] + .data[[sem_col]]),
                           fill=int_colors[[int]], alpha=0.2),
      line = geom_line(aes(y=.data[[mean_col]]), color=int_colors[[int]], linetype=1, size=2)
    )
  }

  alpha_theme = theme(panel.background = element_blank(),
                      panel.grid.major.y = element_blank(),
                      panel.grid.major.x = element_blank(),
                      aspect.ratio = 1,
                      panel.border = element_rect(colour = "black", fill = NA, size=2),
                      plot.margin=unit(c(2,2,2,2), "lines"),
                      legend.direction = 'vertical',
                      axis.text.x = element_text(size=32, colour = "black"),
                      axis.text.y=element_text(size=32, colour = "black"),
                      axis.title.x = element_text(size=40, vjust=0.1),
                      axis.title.y=element_text(size=40, vjust=0.2))

  for (modality in names(alpha_files)) {
    layers = lapply(names(int_colors), alpha_layers, modality=modality)

    tplot = ggplot(data=alpha_df_domCol, aes(x=time)) +
      # Layer 1: error clouds for each condition
      lapply(layers, `[[`, 'ribbon') +
      # Layer 2: mean timecourses
      lapply(layers, `[[`, 'line') +
      scale_x_continuous(name='Time (sec)', breaks=seq(0.5, 4.5, by=0.5), expand=c(0,0)) +
      scale_y_continuous(name='Alpha Power (Prop. re: baseline)', breaks=seq(-3, 1.5, by=0.5), expand=c(0,0)) +
      coord_cartesian(xlim=c(0.5, 4.85), ylim=c(-3, 1.75)) +
      alpha_theme

    windows()
    print(tplot)

    if (savePlots == 1) {
      ggsave(filename = paste0(imgSaveDir, alpha_files[[modality]]), plot=tplot,
             width=8, height=8, units="in")
    }
  }
}


# _____________________________________________________________________________________________________


